import sys, os, time
sys.path.append("../FinRL")
sys.path.append("./")
from ruamel.yaml import YAML
from utils import system

import gym
import numpy as np 
import torch
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import envs
from common.sac import ReplayBuffer, SAC
from utils.plots.train_plot import plot_sac_curve
import datetime

from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3.common.logger import configure
from finrl.meta.data_processor import DataProcessor

from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline
from pprint import pprint
import itertools
from finrl import config
from finrl import config_tickers
import os
from finrl.main import check_and_make_directories
from finrl.config import (
    DATA_SAVE_DIR,
    TRAINED_MODEL_DIR,
    TENSORBOARD_LOG_DIR,
    RESULTS_DIR,
    INDICATORS,
)
# Create the FinRL output directories imported above before anything below
# runs (presumably a no-op for directories that already exist).
check_and_make_directories(
    [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
)

def train_policy(env):
    """Train a SAC expert on ``env`` and return its action function.

    Reads module-level globals that are bound only when this file runs as a
    script (see the ``__main__`` block): ``v`` (parsed YAML config),
    ``device`` (torch device), ``env_T`` (rollout length) and ``seed``.

    Args:
        env: environment whose ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` size the replay buffer.

    Returns:
        The trained agent's ``get_action`` method, which ``evaluate_policy``
        calls as ``policy(obs, deterministic)``.

    Raises:
        ValueError: if the constructed agent does not have ``reinitialize``
            enabled.
    """
    sac_cfg = v['sac']
    # Both the random-exploration phase and the delay before the first
    # update span this many environment steps.
    explore_steps = env_T * sac_cfg['random_explore_episodes']

    replay_buffer = ReplayBuffer(
        env.observation_space.shape[0],
        env.action_space.shape[0],
        device=device,
        size=sac_cfg['buffer_size'])

    # NOTE: every key of v['sac'] is forwarded to SAC as well; a key that
    # duplicates one of the explicit keywords below raises TypeError.
    sac_agent = SAC(env, replay_buffer,
        steps_per_epoch=env_T,
        update_after=explore_steps,
        max_ep_len=env_T,
        seed=seed,
        start_steps=explore_steps,
        device=device,
        **sac_cfg
        )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not sac_agent.reinitialize:
        raise ValueError("train_policy requires the SAC agent to have reinitialize enabled")

    sac_agent.test_fn = sac_agent.test_agent_ori_env
    # The training statistics returned by learn_mujoco are not used here.
    sac_agent.learn_mujoco(print_out=True)

    return sac_agent.get_action

def evaluate_policy(policy, env, n_episodes, deterministic=False, *, horizon=None):
    """Roll out ``policy`` in ``env`` and record visited states and actions.

    Args:
        policy: callable ``policy(obs, deterministic)`` returning a NumPy
            action array of length ``env.action_space.shape[0]``.
        env: environment with ``reset()`` and ``step(...)``; ``step`` is
            called with a one-element list ``[action]`` and must return a
            4-tuple ``(obs, rew, done, info)``.
        n_episodes: number of rollouts to record.
        deterministic: forwarded to ``policy`` on every step.
        horizon: number of steps per rollout. Defaults to the module-level
            ``env_T``, which is bound only in the ``__main__`` block.

    Returns:
        ``(states, actions, returns)``: ``states`` is a float tensor of shape
        ``(n_episodes, horizon, obs_dim)`` holding the observation returned
        *after* each step (s1..sT); ``actions`` has shape
        ``(n_episodes, horizon, act_dim)`` holding the actions taken (a0..aT-1);
        ``returns`` is a NumPy array of per-rollout reward sums.
    """
    if horizon is None:
        horizon = env_T
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    expert_states = torch.zeros((n_episodes, horizon, obs_dim))  # s1 to sT
    expert_actions = torch.zeros((n_episodes, horizon, act_dim))  # a0 to aT-1

    returns = []
    for n in range(n_episodes):
        obs = env.reset()
        ret = 0
        for t in range(horizon):
            action = policy(obs, deterministic)
            # NOTE(review): `done` is ignored, so each rollout always runs
            # `horizon` steps. This assumes rew=0 after an episode ends; an
            # auto-resetting vectorized env would instead continue in a fresh
            # episode. Confirm the env's episode length is >= horizon.
            obs, rew, _, _ = env.step([action])
            # Index assignment copies the data, so no explicit clone is needed.
            expert_states[n, t, :] = torch.from_numpy(obs)
            expert_actions[n, t, :] = torch.from_numpy(action)
            ret += rew
        returns.append(ret)

    return expert_states, expert_actions, np.array(returns)

def _fill_ticker_date_grid(processed):
    """Expand ``processed`` so every ticker has a row on every date present.

    Builds the cross product of all calendar days between the earliest and
    latest ``date`` and all tickers in ``tic``, left-merges ``processed`` onto
    it, keeps only dates that occur in ``processed``, sorts by
    ``(date, tic)`` and replaces every remaining NaN with 0.

    Assumes ``processed['date']`` holds ``'YYYY-MM-DD'`` strings matching the
    ``astype(str)`` rendering of ``pd.date_range`` (TODO confirm); otherwise
    the merge would not line up.
    """
    tickers = processed["tic"].unique().tolist()
    dates = list(pd.date_range(processed['date'].min(), processed['date'].max()).astype(str))
    grid = pd.DataFrame(list(itertools.product(dates, tickers)), columns=["date", "tic"])

    full = grid.merge(processed, on=["date", "tic"], how="left")
    # date_range covers every calendar day; drop days absent from the data.
    full = full[full['date'].isin(processed['date'])]
    full = full.sort_values(['date', 'tic'])
    return full.fillna(0)


def setup_training_env(train_start_date='2021-01-01', train_end_date='2022-01-01'):
    """Build the FinRL DOW-30 ``StockTradingEnv`` used to train the expert.

    Fetches DOW-30 price data via ``YahooDownloader`` (needs network access),
    adds the ``INDICATORS`` technical indicators plus VIX and turbulence
    features, fills the (date, ticker) grid, and wraps the
    ``[train_start_date, train_end_date)`` slice (per ``data_split``) in a
    ``StockTradingEnv``.

    Args:
        train_start_date: start of the training window (``'YYYY-MM-DD'``).
        train_end_date: end of the training window (``'YYYY-MM-DD'``).
            Both must fall inside the fixed download window
            2010-01-01 .. 2023-03-01.

    Returns:
        The first element of ``StockTradingEnv.get_sb_env()``.
    """
    # The download window is deliberately wider than the training window and
    # kept fixed: FeatureEngineer processes the whole downloaded frame, so
    # changing it could change the computed features for the training slice.
    download_start_date = '2010-01-01'
    download_end_date = '2023-03-01'

    df = YahooDownloader(start_date=download_start_date, end_date=download_end_date,
                         ticker_list=config_tickers.DOW_30_TICKER).fetch_data()

    fe = FeatureEngineer(use_technical_indicator=True, tech_indicator_list=INDICATORS,
                         use_vix=True, use_turbulence=True, user_defined_feature=False)
    processed = fe.preprocess_data(df)
    processed_full = _fill_ticker_date_grid(processed)

    train = data_split(processed_full, train_start_date, train_end_date)

    stock_dimension = len(train.tic.unique())
    # Presumably cash + (price, holdings) per stock + each indicator per
    # stock, matching StockTradingEnv's state layout (TODO confirm).
    state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
    # Separate lists so buy and sell costs are not the same mutable object.
    buy_cost_list = [0.001] * stock_dimension
    sell_cost_list = [0.001] * stock_dimension
    num_stock_shares = [0] * stock_dimension
    env_kwargs = {
        "hmax": 100,
        "initial_amount": 1000,
        "num_stock_shares": num_stock_shares,
        "buy_cost_pct": buy_cost_list,
        "sell_cost_pct": sell_cost_list,
        "state_space": state_space,
        "stock_dim": stock_dimension,
        "tech_indicator_list": INDICATORS,
        "action_space": stock_dimension,
        "reward_scaling": 1
    }
    e_train_gym = StockTradingEnv(df=train, **env_kwargs)
    env_train, _ = e_train_gym.get_sb_env()
    return env_train


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"usage: python {sys.argv[0]} <config.yml>")
    yaml = YAML()
    with open(sys.argv[1]) as config_file:
        v = yaml.load(config_file)

    # common parameters
    # NOTE: v, env_T, seed and device are read as module globals by
    # train_policy / evaluate_policy, so they must stay at module scope.
    env_name, env_T = v['env']['env_name'], v['env']['T']
    seed = v['seed']

    # system: device, threads, seed
    device = torch.device(f"cuda:{v['cuda']}" if torch.cuda.is_available() and v['cuda'] >= 0 else "cpu")
    torch.set_num_threads(1)
    np.set_printoptions(precision=3, suppress=True)
    system.reproduce(seed)

    # NOTE(review): env_name only labels the printout and output files; the
    # env itself is always the one built by setup_training_env().
    env = setup_training_env()

    print(f"training Expert on {env_name}")
    policy = train_policy(env)

    # Create every output directory up front, including meta/, which the log
    # file below is written into.
    for sub_dir in ('meta', 'states', 'actions'):
        os.makedirs(f'expert_data/{sub_dir}/', exist_ok=True)

    n_eval_episodes = v['expert']['samples_episode']

    expert_states_sto, expert_actions_sto, expert_returns = evaluate_policy(policy, env, n_eval_episodes)
    return_info = f'Expert(Sto) Return Avg: {expert_returns.mean():.2f}, std: {expert_returns.std():.2f}'
    print(return_info)

    # `with` guarantees the log is flushed and closed even if a later step fails.
    with open(f"expert_data/meta/{env_name}_{seed}.txt", 'w') as log_txt:
        log_txt.write(return_info + '\n')
        log_txt.write(repr(expert_returns) + '\n')

        expert_states_det, expert_actions_det, expert_returns = evaluate_policy(policy, env, n_eval_episodes, True)
        return_info = f'Expert(Det) Return Avg: {expert_returns.mean():.2f}, std: {expert_returns.std():.2f}'
        print(return_info)
        log_txt.write(return_info + '\n')
        log_txt.write(repr(expert_returns) + '\n')

        log_txt.write(repr(v))

    torch.save(expert_states_sto, f'expert_data/states/{env_name}_{seed}_sto.pt')
    torch.save(expert_states_det, f'expert_data/states/{env_name}_{seed}_det.pt')
    torch.save(expert_actions_sto, f'expert_data/actions/{env_name}_{seed}_sto.pt')
    torch.save(expert_actions_det, f'expert_data/actions/{env_name}_{seed}_det.pt')